use std::sync::{atomic::AtomicU32, Arc, LazyLock};
use thiserror::Error;

use byteorder::{BE, LE};
use common::{
    cryptography::{Ec2b, MhyXorpad},
    util,
};
use proto::{CmdID, NapMessage, PacketHead};
use tokio::{
    io::AsyncWriteExt,
    net::{
        tcp::{OwnedReadHalf, OwnedWriteHalf},
        TcpStream,
    },
    sync::{Mutex, OnceCell},
};

use crate::{
    handlers::{self, PacketHandlingError},
    ServerState,
};

use super::packet::DecodeError;
use super::NetPacket;

/// Process-wide xorpad derived from the Ec2b secret key file.
///
/// `NetSession::xor_payload` uses it whenever a session has no key of its own
/// yet, and always for `PlayerGetTokenScRsp`.
///
/// # Panics
///
/// The first access panics if the secret key file cannot be opened or its
/// Ec2b contents cannot be read.
static SECRET_KEY: LazyLock<MhyXorpad> = LazyLock::new(|| {
    // Keep the file handle scoped to the read so it is released right away.
    let ec2b = {
        let mut key_file = util::open_secret_key().expect("Failed to open secret key file");
        Ec2b::read(&mut key_file).expect("Failed to read Ec2b data")
    };
    let seed = ec2b.derive_seed();

    MhyXorpad::new::<LE>(seed)
});

/// A single client connection together with its per-connection protocol state.
///
/// All fields use interior mutability (mutexes, atomics, once-cells), so every
/// method on the session takes `&self`.
pub struct NetSession {
    /// Identifier of this session; used to remove it from the session manager
    /// on disconnect.
    id: u64,
    /// Read half of the TCP stream; locked while a packet is being read in `run`.
    reader: Mutex<OwnedReadHalf>,
    /// Write half of the TCP stream; locked for each outgoing packet and for
    /// `shutdown`.
    writer: Mutex<OwnedWriteHalf>,
    /// Per-session xorpad, settable once via `set_session_key`. While unset,
    /// payloads are XORed with the global `SECRET_KEY` instead.
    session_key: OnceCell<MhyXorpad>,
    /// Source of `packet_id` values for outgoing packets; starts at 0.
    packet_id_counter: AtomicU32,
    /// Current protocol state; decides which incoming commands are handled.
    state: AtomicNetSessionState,
    /// Account UID, settable once via `set_account_uid`.
    account_uid: OnceCell<String>,
    /// Player UID, settable once via `set_player_uid`. Once set, the player is
    /// saved periodically in `run` and saved-and-removed on disconnect.
    player_uid: OnceCell<u32>,
}

/// Protocol state of a [`NetSession`].
///
/// The `atomic_enum` attribute generates the companion `AtomicNetSessionState`
/// type that `NetSession` stores.
///
/// Variant declaration order is significant: the derived `PartialOrd` is used
/// by `is_command_allowed` and `is_auth` to compare states with `<` and `>=`,
/// so reordering variants changes which commands are accepted.
#[atomic_enum::atomic_enum]
#[derive(PartialEq, Eq, PartialOrd)]
pub enum NetSessionState {
    /// Initial state of every new session.
    StartEnterGameWorld,
    PlayerGetTokenCsReq,
    PlayerGetTokenScRsp,
    PlayerLoginCsReq,
    /// First state for which `is_auth` returns `false`.
    PlayerLoginScRsp,
    /// First state in which commands other than the token/login requests are
    /// allowed.
    StartBasicsReq,
    EndBasicsReq,
    EnterWorldScRsp,
}

impl NetSessionState {
    /// Reports whether a packet with `cmd_id` may be processed in this state.
    ///
    /// * `PlayerGetTokenCsReq` is accepted only in `StartEnterGameWorld`.
    /// * `PlayerLoginCsReq` is accepted only in `PlayerGetTokenScRsp`.
    /// * Every other command requires the state to be `StartBasicsReq` or later.
    pub fn is_command_allowed(&self, cmd_id: u16) -> bool {
        if cmd_id == proto::PlayerGetTokenCsReq::CMD_ID {
            matches!(self, Self::StartEnterGameWorld)
        } else if cmd_id == proto::PlayerLoginCsReq::CMD_ID {
            matches!(self, Self::PlayerGetTokenScRsp)
        } else {
            *self >= Self::StartBasicsReq
        }
    }

    /// Reports whether the session is still in a state that precedes
    /// `PlayerLoginScRsp`, i.e. packets are routed to the auth handlers.
    pub fn is_auth(&self) -> bool {
        Self::PlayerLoginScRsp > *self
    }
}

/// Errors that terminate [`NetSession::run`].
#[derive(Error, Debug)]
pub enum SessionError {
    /// An incoming packet could not be decoded. `run` reports every
    /// `DecodeError` this way except `DecodeError::IoError`, which it treats
    /// as a normal disconnect.
    #[error("NetPacket decode failed: {0}")]
    PacketDecode(#[from] DecodeError),
    /// A packet handler failed. `run` reports every `PacketHandlingError`
    /// this way except `PacketHandlingError::Logout`, which ends the session
    /// successfully.
    #[error("failed to handle packet: {0}")]
    PacketHandling(#[from] PacketHandlingError),
}

impl NetSession {
    /// Creates a session with the given `id` over `stream`.
    ///
    /// The stream is split into owned read and write halves, each guarded by
    /// its own mutex. The session starts in
    /// [`NetSessionState::StartEnterGameWorld`] with the packet ID counter at
    /// zero and no session key, account UID, or player UID set.
    pub fn new(id: u64, stream: TcpStream) -> Self {
        let (reader, writer) = stream.into_split();

        Self {
            id,
            reader: Mutex::new(reader),
            writer: Mutex::new(writer),
            session_key: OnceCell::new(),
            packet_id_counter: AtomicU32::new(0),
            state: AtomicNetSessionState::new(NetSessionState::StartEnterGameWorld),
            account_uid: OnceCell::new(),
            player_uid: OnceCell::new(),
        }
    }

    /// Reads and handles packets until the session ends.
    ///
    /// After each handled packet, if a player UID is set and at least
    /// `state.config.player_save_period_seconds` has passed (measured with
    /// `util::cur_timestamp`) since the last save, the player is saved.
    ///
    /// The disconnect cleanup always runs before this function returns.
    ///
    /// # Errors
    ///
    /// Returns `Ok(())` when reading fails with `DecodeError::IoError` or a
    /// handler returns `PacketHandlingError::Logout`. Any other decode or
    /// handling error is returned as a [`SessionError`].
    pub async fn run(&self, state: Arc<ServerState>) -> Result<(), SessionError> {
        let mut last_save_time = util::cur_timestamp();

        let result = loop {
            let packet = match NetPacket::read(&mut *self.reader.lock().await).await {
                Ok(packet) => packet,
                // A read-side I/O error is treated as the peer going away, not
                // as a session failure.
                Err(DecodeError::IoError(_)) => break Ok(()),
                Err(err) => break Err(SessionError::PacketDecode(err)),
            };

            match self.handle_packet(packet, &state).await {
                Ok(()) => (),
                Err(PacketHandlingError::Logout) => break Ok(()),
                Err(err) => break Err(SessionError::PacketHandling(err)),
            }

            if let Some(uid) = self.player_uid.get() {
                if (util::cur_timestamp() - last_save_time)
                    >= state.config.player_save_period_seconds
                {
                    state.player_mgr.save(*uid).await;
                    // Restart the period from the moment the save finished.
                    last_save_time = util::cur_timestamp();
                }
            }
        };

        self.on_disconnect(&state).await;
        result
    }

    /// XORs the packet body with the applicable xorpad and dispatches the
    /// packet according to the current session state.
    ///
    /// * Commands not allowed in the current state are logged and dropped.
    /// * While the state is an auth state, the packet goes to the auth handler.
    /// * Otherwise it is offered to the request handler and, if that does not
    ///   handle it, to the notify handler.
    ///
    /// A packet that no handler accepts is logged with a hex dump of its body.
    ///
    /// # Errors
    ///
    /// Propagates any `PacketHandlingError` returned by a handler.
    async fn handle_packet(
        &self,
        mut packet: NetPacket,
        state: &ServerState,
    ) -> Result<(), PacketHandlingError> {
        self.xor_payload(packet.cmd_id, &mut packet.body);
        // Take one snapshot of the state so the decision below and the warning
        // that reports it always refer to the same value.
        let net_state = self.state.load(std::sync::atomic::Ordering::SeqCst);

        if !net_state.is_command_allowed(packet.cmd_id) {
            tracing::warn!(
                "received cmd_id ({}) is not allowed in current state ({:?})",
                packet.cmd_id,
                net_state
            );
        } else if net_state.is_auth() {
            if !handlers::handle_auth_request(self, &packet, state).await? {
                tracing::warn!(
                    "[LOGIN] packet with cmd_id={} wasn't handled, body: {}",
                    packet.cmd_id,
                    hex::encode(&packet.body)
                );
            }
        } else if !handlers::handle_request(self, &packet, state).await?
            // Short-circuit: the notify handler only runs when the request
            // handler did not take the packet.
            && !handlers::handle_notify(self, &packet, state).await?
        {
            tracing::warn!(
                "packet with cmd_id={} wasn't handled, body: {}",
                packet.cmd_id,
                hex::encode(&packet.body)
            );
        }

        Ok(())
    }

    /// Removes this session from the session manager and, if a player UID was
    /// set, saves and removes that player from the player manager.
    async fn on_disconnect(&self, state: &ServerState) {
        state.session_mgr.remove(self.id);
        if let Some(player_uid) = self.player_uid.get() {
            state.player_mgr.save_and_remove(*player_uid).await;
        }
    }

    /// Sends `ntf` to the client as a packet with a fresh packet ID and a
    /// default `request_id`.
    ///
    /// `xor_fields` is applied to the message before it is encoded.
    ///
    /// # Errors
    ///
    /// Returns the I/O error from writing to the socket.
    pub async fn notify(&self, mut ntf: impl NapMessage) -> Result<(), std::io::Error> {
        ntf.xor_fields();

        self.send(NetPacket {
            cmd_id: ntf.get_cmd_id(),
            head: PacketHead {
                packet_id: self.next_packet_id(),
                ..Default::default()
            },
            body: ntf.encode_to_vec().into_boxed_slice(),
        })
        .await
    }

    /// Sends `rsp` to the client as the response to the request identified by
    /// `request_id`, using a fresh packet ID.
    ///
    /// `xor_fields` is applied to the message before it is encoded.
    ///
    /// # Errors
    ///
    /// Returns the I/O error from writing to the socket.
    pub async fn send_rsp(
        &self,
        request_id: u32,
        mut rsp: impl NapMessage,
    ) -> Result<(), std::io::Error> {
        rsp.xor_fields();

        self.send(NetPacket {
            cmd_id: rsp.get_cmd_id(),
            head: PacketHead {
                packet_id: self.next_packet_id(),
                request_id,
                ..Default::default()
            },
            body: rsp.encode_to_vec().into_boxed_slice(),
        })
        .await
    }

    /// XORs the packet body, encodes the packet, and writes it to the socket
    /// while holding the writer lock.
    async fn send(&self, mut packet: NetPacket) -> Result<(), std::io::Error> {
        self.xor_payload(packet.cmd_id, &mut packet.body);

        let buf = packet.encode();
        self.writer.lock().await.write_all(&buf).await
    }

    /// Returns the identifier this session was created with.
    pub fn id(&self) -> u64 {
        self.id
    }

    /// Installs the per-session xorpad derived from `seed`.
    ///
    /// The key can only be set once; a later call leaves the existing key in
    /// place and logs a warning instead of failing silently.
    pub fn set_session_key(&self, seed: u64) {
        if self.session_key.set(MhyXorpad::new::<BE>(seed)).is_err() {
            tracing::warn!(
                "session key for session {} is already set, ignoring new seed",
                self.id
            );
        }
    }

    /// Returns the account UID, or `None` if it has not been set yet.
    pub fn account_uid(&self) -> Option<&String> {
        self.account_uid.get()
    }

    /// Sets the account UID. Returns `false` if it could not be set because a
    /// value was already present.
    pub fn set_account_uid(&self, uid: String) -> bool {
        self.account_uid.set(uid).is_ok()
    }

    /// Returns the player UID, or `None` if it has not been set yet.
    pub fn player_uid(&self) -> Option<&u32> {
        self.player_uid.get()
    }

    /// Sets the player UID. Returns `false` if it could not be set because a
    /// value was already present.
    pub fn set_player_uid(&self, uid: u32) -> bool {
        self.player_uid.set(uid).is_ok()
    }

    /// XORs `buf` in place with the xorpad that applies to `cmd_id`.
    ///
    /// `PlayerGetTokenScRsp` always uses the global `SECRET_KEY`; every other
    /// command uses the session key when one is set and falls back to
    /// `SECRET_KEY` otherwise.
    fn xor_payload(&self, cmd_id: u16, buf: &mut [u8]) {
        let key = if cmd_id == proto::PlayerGetTokenScRsp::CMD_ID {
            &*SECRET_KEY
        } else {
            // Resolve the fallback lazily so the global key is only touched
            // when it is actually needed.
            self.session_key.get().unwrap_or_else(|| &*SECRET_KEY)
        };

        key.xor(buf);
    }

    /// Returns the next outgoing packet ID and advances the counter.
    ///
    /// The value returned is the counter before the increment, so the first
    /// packet of a session gets ID 0.
    fn next_packet_id(&self) -> u32 {
        self.packet_id_counter
            .fetch_add(1, std::sync::atomic::Ordering::SeqCst)
    }

    /// Replaces the current protocol state with `state`.
    pub fn set_state(&self, state: NetSessionState) {
        self.state.store(state, std::sync::atomic::Ordering::SeqCst);
    }

    /// Shuts down the write half of the connection.
    ///
    /// # Errors
    ///
    /// Returns the I/O error reported by the underlying shutdown.
    pub async fn shutdown(&self) -> Result<(), std::io::Error> {
        self.writer.lock().await.shutdown().await
    }
}
